Analysis and figures for Mike's AGU talk

Basic setup

All of the models below are trained with the lowest 21 levels of CAM and use the following inputs/features to predict the following outputs/targets:

Features: [TAP, QAP, dTdt_adiabatic, dQdt_adiabatic, SHFLX, LHFLX, SOLIN], where the first 4 variables are 3D fields (so 21 levels each) and the last 3 are 2D fields with 1 level. For the network they were stacked to create one vector of length 87 (4 × 21 + 3) and then normalized (subtract mean, divide by std).

Targets: [SPDT, SPDQ, QRL, QRS, PRECT, FLUT]. The first 4 are 3D, the last 2 are 2D. Again the variables were stacked to create a vector of length 86 (4 × 21 + 2). Here we rescaled the variables to be roughly the same order of magnitude (however, this might not be done so cleverly, see the analysis below). The rescaling factors for the variables are [C_P, L_V, C_P, C_P, 1e3*24*3600 * 1e-3, 1e5].
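As a sketch of the rescaling described above (variable and function names here are hypothetical, not the ones in utils), the length-86 scale vector could be built like this:

```python
import numpy as np

C_P = 1e3    # specific heat of dry air [J/kg/K]
L_V = 2.5e6  # latent heat of vaporization [J/kg]

# One rescaling factor per target variable, repeated over its levels
# (21 for the 3D fields, 1 for the 2D fields) -> a length-86 vector.
target_scales = np.concatenate([
    np.full(21, C_P),          # SPDT
    np.full(21, L_V),          # SPDQ
    np.full(21, C_P),          # QRL
    np.full(21, C_P),          # QRS
    [1e3 * 24 * 3600 * 1e-3],  # PRECT
    [1e5],                     # FLUT
])

def rescale_targets(stacked_targets):
    """Apply the per-level factors to stacked targets of shape (n_samples, 86)."""
    return stacked_targets * target_scales
```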

Train/Validation split: The data for training and testing were split by time. We have 2+ years of Aquaplanet simulations. For validation we always use the entirety of year 2. For training, we use subsections of year 1 (to test sensitivity to the training length). For training, the samples (which are flattened over time, lat and lon) are randomly shuffled. This actually brings a large improvement in score (thanks to Gregor Orban for the tip!). 1 year of data equals around 143 million samples.
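The shuffling can be sketched like this (a minimal illustration with made-up names; the real pipeline iterates over the HDF5 files with a generator):

```python
import numpy as np

def shuffled_batches(features, targets, batch_size=1024, seed=0):
    """Yield (features, targets) batches in a random order of the flattened
    (time, lat, lon) samples, so each batch mixes times and locations."""
    rng = np.random.RandomState(seed)
    idx = rng.permutation(features.shape[0])
    for start in range(0, len(idx), batch_size):
        sel = idx[start:start + batch_size]
        yield features[sel], targets[sel]
```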

In [383]:
# Imports
from importlib import reload
import utils; reload(utils); from utils import *   # star import also provides np, plt and sns used below
import tensorflow as tf
import keras
import netCDF4 as nc
import h5py
sns.set_style('dark')
sns.set_context('talk')
%matplotlib inline
import pandas as pd
from IPython.display import SVG, HTML
from keras.utils.vis_utils import model_to_dot

# Setup
config = tf.ConfigProto()
config.gpu_options.allow_growth = True   # Allocates as much memory as needed.
keras.backend.tensorflow_backend.set_session(tf.Session(config=config))
sns.set_style("white")
from matplotlib.animation import FuncAnimation 
plt.rcParams['savefig.facecolor']='black'
plt.rcParams['text.color'] = 'white'
In [2]:
data_dir = '/local/S.Rasp/cbrain_data/'
In [31]:
feature_names = ['TAP', 'QAP', 'dTdt_adiabatic', 'dQdt_adiabatic', 'SHFLX', 'LHFLX', 'SOLIN']
target_names = ['SPDT', 'SPDQ', 'QRL', 'QRS', 'PRECT', 'FLUT']

Neural network cheat sheet

Just a brain dump from my side with everything someone might ask.

  • batch size = 1024
  • activation function: relu = max(0, x)
  • training time for "best" network (1024-1024-512-512) on a single Nvidia GPU: 20 hours
  • Why use powers of two for batch size and number of hidden nodes? Because of the way parallel processing is handled on a GPU: power-of-two sizes map neatly onto the hardware.
  • Number of training (and also validation) samples: approx 143 million
  • Size of training data: approx 100GB
  • Optimizer: Adam, which is an improved version of classic stochastic gradient descent (using momentum and adaptive learning rates for each parameter)
  • Learning rate: starting with 1e-3, halved every 5 epochs
  • Epochs: 30
  • Keras: A high-level neural network library for Python, running on top of TensorFlow, a machine learning framework developed by Google.
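The learning rate schedule above (start at 1e-3, halve every 5 epochs) can be written as a small step-decay function, e.g. to pass to keras.callbacks.LearningRateScheduler (a sketch, not the exact training code):

```python
def step_decay(epoch, initial_lr=1e-3, drop=0.5, epochs_per_drop=5):
    """Start at initial_lr and halve it every epochs_per_drop epochs."""
    return initial_lr * drop ** (epoch // epochs_per_drop)

# Epochs 0-4 -> 1e-3, epochs 5-9 -> 5e-4, epochs 10-14 -> 2.5e-4, ...
```

In Keras this would be hooked into training as `keras.callbacks.LearningRateScheduler(step_decay)`.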

Scores/metrics and comparison with TF version

Let's start out by looking a little at which score to use to compare the different architectures and training lengths. This issue actually causes me quite a bit of grief. Here goes:

  • The way the scores are computed in Keras vs. TF is slightly different (e.g. the order in which the scores are averaged over the z and sample dimensions). And some metrics (R2 and, I think, also the log loss as implemented in TF) depend on the batch size, which differs between my runs and Pierre's. This means right now the scores are unfortunately not comparable :(
  • Additionally, Pierre is predicting SPDT and SPDQ separately from QRS and QRL. I think it will be interesting to try out how the fits respond to that in the future. But again, this makes an immediate comparison impossible.
  • Despite our rescaling, the magnitudes of the target variables are not really similar (see plot below). The fits are still OK, as you can see below. This is definitely something to go back to (output normalization?). But it also means that many of the average scores weigh some variables more heavily than others.

This last issue makes it hard to compare different experiments with one score, but for the sensitivities below, I want to do it anyway. What gives me confidence is that the order of the experiments is similar for different scores. The score I will use is the mean explained variance, defined as $$ \frac{1}{N_z}\sum_z \left[1 - \frac{\mathrm{squared\,error}(z)}{\mathrm{target\,variance}(z)}\right], $$ where $z$ is the stacked vertical dimension, so 21 levels each of SPDT, SPDQ, QRL and QRS plus one level each for precipitation and OLR.
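In NumPy, this score can be sketched as follows (a hypothetical helper, not the utils implementation):

```python
import numpy as np

def mean_explained_variance(y_true, y_pred):
    """Mean over the stacked z dimension of 1 - MSE(z) / Var(z).

    y_true, y_pred have shape (n_samples, n_z); here n_z = 86
    (4 x 21 levels plus PRECT and FLUT).
    """
    sq_err = np.mean((y_pred - y_true) ** 2, axis=0)  # squared error per level
    var = np.var(y_true, axis=0)                      # target variance per level
    return np.mean(1.0 - sq_err / var)
```

A perfect prediction scores 1; predicting a constant scores at most 0 for every level.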

Later on we will look at the individual variables.

In [41]:
# Load means and stds
norm = nc.Dataset(data_dir + 'year1_norm.nc')
In [6]:
plt.plot(norm.variables['target_means'], label='mean')
plt.plot(norm.variables['target_stds'], label='std')
plt.legend()
plt.show()

The variables in order are SPDT, SPDQ, QRL, QRS, PRECT, FLUT. Lower levels are further to the right. This basically means that the lower levels of SPDQ are much larger in magnitude than the other variables.

Sensitivity to network architecture

Here I will just compare a few model architectures. All of them were trained with a training set containing day 1-3 from each of the 12 months in year 1 of the data.

In [7]:
model_names = [
    '005_1month_distr_shuffle',   # Base model: 1024, 1024, 512, 512 --> 1,971,286 parameters
    '011_1month_log_loss',        # Same model, but with log_loss minimization
    '021_1month_miniscule',       # Dense, one hidden layer 256 nodes --> 44,630 parameters
    '016_1month_tiny',            # One hidden layer 512 nodes
    '017_1month_small',           # 512-512
    '018_1month_medium',          # 1024, 1024
    '022_1month_narrow_deep',     # 128, 128, 128, 128
    '023_1month_broad_shallow',   # 2048
    '019_1month_conv',            # Convolution with 6x32 + 1x6 + flatten + dense linear
    '020_1month_local_conv',      # Local convolution with 2x20 + 1x6 + flatten + dense linear
]
In [8]:
# # Compute scores for all models, takes a long time. Results are saved!
# score_list = []
# for model in model_names:
#     score_list.append(run_diagnostics(
#     './models/' + model + '.h5', 
#     data_dir,
#     'valid_year1', 
#     data_dir + 'year1_norm.nc',
#     convo=True if 'conv' in model else False,
# ))
In [9]:
# df = pd.DataFrame({
#     'Experiment': [n[4:] for n in model_names], 
#     'Explained variance': [np.mean(1-v[0]) for v in score_list],
# })
# df.to_csv('./network_sensitivity.csv')
In [10]:
# # Go through all models and save number of parameters
# param_list = []
# for model in model_names:
#     m = keras.models.load_model('./models/' + model + '.h5')
#     param_list.append(m.count_params())
In [11]:
# df['Parameters'] = param_list
In [12]:
# df.to_csv('./network_sensitivity.csv')
In [13]:
df = pd.read_csv('./network_sensitivity.csv')
In [380]:
sns.barplot(y='Experiment', x='Explained variance', data=df, palette='hls')
plt.tight_layout()
In [381]:
sns.barplot(y='Experiment', x='Parameters', data=df, palette='hls')
plt.xscale('log')

The best network is the default 1024-1024-512-512 dense model. After that, the results are confusing because the fits do not scale with the number of parameters. For example, a single-layer network with 256 nodes (miniscule) performs much better than a two-layer network with 512-512 nodes. In general, it seems that shallow, wide networks are better than narrow, deep ones.

My suggestions for the talk

For the AGU talk maybe a figure plotting the number of parameters against the score would be good. See a first draft below. If you think this is the best way of displaying this sensitivity, I would color-code and annotate the figure highlighting the different architectures. As key points for this sensitivity my suggestions would be:

  • A very complex network (approx. 2 million parameters) provides the best results, but more parameters do not automatically produce a better fit
  • Network architecture matters! Wide, shallow networks seem to be better than narrow, deep networks.
  • Convolutions could provide a way of reducing the number of parameters (my results do not show this, but I think we need to test this further)
  • Similar results for mean absolute error and log RMSE minimization.
In [399]:
# Drop log loss for clarity
df_no_log = df.drop(1)
In [402]:
df_no_log['Experiment']
Out[402]:
0    1month_distr_shuffle
2        1month_miniscule
3             1month_tiny
4            1month_small
5           1month_medium
6      1month_narrow_deep
7    1month_broad_shallow
8             1month_conv
9       1month_local_conv
Name: Experiment, dtype: object
In [436]:
# Choose colors
colors = ['magenta', '#fc9272', '#cb181d', '#74c476', '#238b45', 
          '#00441b', '#67000d', '#6baed6', '#08519c']
In [461]:
# plot for the talk
sns.set_style('white')
plt.rcParams['xtick.color'] = 'white'
plt.rcParams['ytick.color'] = 'white'
fig, ax = plt.subplots(1, 1, figsize=(6, 5), facecolor='black')
ax.set_facecolor('whitesmoke')
plt.scatter(df_no_log['Parameters'], df_no_log['Explained variance'], c=colors, s=150)
plt.xscale('log')
plt.xlabel('Number of parameters', color='white')
plt.ylabel('Explained variance', color='white')
plt.title('Sensitivity to network complexity and architecture', color='white');
In [462]:
fig.savefig('/home/s/S.Rasp/tmp/agu_presentation/agu_network_sensitivity.png', dpi=300)

Here you can add some labels in the presentation. Maybe just arrows in the respective colors with a description:

  • Magenta is the complex, "best" network
  • In green are the narrow, deep networks.
  • In red are the shallow, wide networks
  • In blue are the convolutional networks

Narrow, deep means more hidden layers with fewer nodes per layer; shallow, wide means only one hidden layer with more nodes.

My CNN architectures

A note about my convolutional networks. The architecture is slightly different than Pierre's because I am predicting SPD*, QR*, PRECT and FLUT at the same time, see graphs and explanation below.

In [17]:
model = keras.models.load_model('./models/019_1month_conv.h5')
SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))
Out[17]:
[Model graph: input_1 InputLayer (None, 21, 7) → conv1d_1 … conv1d_6 Conv1D (None, 21, 32) → conv1d_7 Conv1D (None, 21, 6) → flatten_1 Flatten (None, 126) → dense_1 Dense (None, 86)]

I think the architecture mirrors the TF implementation until the last Conv1D layer. At this point in TF there is a final convolution layer with two channels, which are then taken directly as the output (e.g. SPDT and SPDQ). In my case I am using a convolution layer with 6 channels, followed by a flattening operation and a final dense layer. This final dense layer makes up around half of the parameters of the model. Next up is the locally connected model.
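As a quick check on the parameter claim above: the final dense layer maps the flattened 126-vector to the 86 stacked outputs, so its size follows directly from the shapes in the graph.

```python
# Final Dense layer of the conv model: (None, 126) -> (None, 86).
# Parameters = weight matrix + biases.
n_in, n_out = 126, 86
dense_params = n_in * n_out + n_out
print(dense_params)  # 10922
```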

In [20]:
model = keras.models.load_model('./models/020_1month_local_conv.h5')
SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))
Out[20]:
[Model graph: input_1 InputLayer (None, 21, 7) → locally_connected1d_1 LocallyConnected1D (None, 19, 20) → locally_connected1d_2 LocallyConnected1D (None, 17, 20) → locally_connected1d_3 LocallyConnected1D (None, 15, 6) → flatten_1 Flatten (None, 90) → dense_1 Dense (None, 86)]

In my quick test, a model with one dense layer provides better results with a similar number of parameters. Note also that convolution operations take more time than the matrix multiplication in a dense layer. But of course, my architecture is different from the TF implementation, so that is something to explore again later.
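For reference, the locally connected model's parameter count can be reconstructed from the layer shapes above (a sketch; the kernel size of 3 is inferred from the 21 → 19 → 17 → 15 shrinkage, and LocallyConnected1D uses a separate, unshared kernel at every output position):

```python
def local_conv_params(out_len, kernel, in_ch, filters):
    """Unshared 1D convolution: one (kernel * in_ch + 1) x filters set of
    weights and biases per output position, so parameters scale with out_len."""
    return out_len * (kernel * in_ch + 1) * filters

total = (local_conv_params(19, 3, 7, 20)      # locally_connected1d_1
         + local_conv_params(17, 3, 20, 20)   # locally_connected1d_2
         + local_conv_params(15, 3, 20, 6)    # locally_connected1d_3
         + 90 * 86 + 86)                      # final Dense: 90 -> 86
print(total)  # 42416
```

That is roughly the size of the one-hidden-layer "miniscule" dense model (44,630 parameters), consistent with the comparison above.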

Sensitivity to amount of training data

For this we are comparing the following training sets:

  • 1 month continuous
  • days 1-3 from each month (roughly comparable amount of data)
  • 2, 3 and 6 months continuous
  • 1 year
In [27]:
model_names = [
    '024_1month_cont_shuffle',
    '005_1month_distr_shuffle',
    '026_2month_cont_shuffle',
    '025_3month_cont_shuffle',
    '027_6month_cont_shuffle',
    '002_1year_shuffle'
]
In [ ]:
# # Compute scores for all models... running in iPython notebook
# score_list = []
# for model in model_names:
#     score_list.append(run_diagnostics(
#     './models/' + model + '.h5', 
#     data_dir,
#     'valid_year1', 
#     data_dir + 'year1_norm.nc',
#     convo=True if 'conv' in model else False,
# ))
In [ ]:
# df = pd.DataFrame({
#     'Experiment': [n[4:] for n in model_names], 
#     'Explained variance': [np.mean(1-v[0]) for v in score_list],
# })
# df.to_csv('./training_sensitivity.csv')
In [449]:
df2 = pd.read_csv('./training_sensitivity.csv').drop(1)   # Drop day1-3
In [450]:
df2
Out[450]:
Unnamed: 0 Experiment Explained variance Months
0 0 1month_cont_shuffle 0.667127 1
2 2 3month_cont_shuffle 0.689924 3
3 3 1year_shuffle 0.700110 12
4 0 6month_cont_shuffle 0.696863 6
5 1 2month_cont_shuffle 0.682665 2
In [451]:
sns.barplot(y='Experiment', x='Explained variance', data=df2, palette='hls')
plt.tight_layout()
In [466]:
# plot for the talk
sns.set_style('white')
plt.rcParams['xtick.color'] = 'white'
plt.rcParams['ytick.color'] = 'white'
fig, ax = plt.subplots(1, 1, figsize=(6, 5), facecolor='black')
ax.set_facecolor('whitesmoke')
plt.scatter(df2['Months'], df2['Explained variance'], s=150, c='#cb181d')
plt.xlabel('Months of training data', color='white')
plt.ylabel('Explained variance', color='white')
plt.xlim(0, 13)
plt.ylim(0.66, 0.71)
plt.title('Sensitivity to amount of training data', color='white');
In [467]:
fig.savefig('/home/s/S.Rasp/tmp/agu_presentation/agu_data_sensitivity.png', dpi=300)

For the talk I would suggest a plot like the one above, showing the amount of training data against the score. As key points, I would suggest:

  • More training data gives a better validation fit, but the improvement levels off after about 6 months
  • In our Aquaplanet data, it makes little difference whether the training data comes from one continuous month or is distributed across the seasons. Not sure how this would generalize to a real globe with continents.
  • This could motivate short (approx. 6 months), expensive runs.

More detailed analysis of network predictions

For all further analysis we will use the "best" model which is the 1024-1024-512-512 model with the MAE minimization. Everything below is done using the validation data that the network has not seen during training!

Plot random samples

These are some random samples showing the input and output profiles. Not sure if you want that, but if you have time it could illustrate the task of the network!

In [33]:
# Load some data
target_file = h5py.File(data_dir + 'valid_year1_targets.nc', 'r')
feature_file = h5py.File(data_dir + 'valid_year1_features.nc', 'r')
In [34]:
# Dimensions to be able to reshape array
n_lon = 128; n_lat = 64; n_geo = n_lat * n_lon
In [36]:
# Load the first 3 days
sample_targets = target_file['targets'][:3*n_geo, :]
sample_features = feature_file['features'][:3*n_geo, :]
In [71]:
# Load the best model
model = keras.models.load_model('./models/002_1year_shuffle.h5')
In [39]:
# Get predictions
sample_preds = model.predict(sample_features)
In [477]:
unitdict = {
    'TAP': 'K',
    'QAP': 'kg/kg',
    'dTdt_adiabatic': 'K/s',
    'dQdt_adiabatic': 'kg/kg/s',
    'SHFLX': r'W/m$^2$',
    'LHFLX': r'W/m$^2$',
    'SOLIN': r'W/m$^2$',
    'SPDT': 'K/s',
    'SPDQ': 'kg/kg/s',
    'QRL': 'K/s',
    'QRS': 'K/s',
    'PRECT': 'mm/h',
    'FLUT': r'W/m$^2$',
}
In [632]:
rangedict = {
    'TAP': (180, 310),
    'QAP': (0, 0.020),
    'dTdt_adiabatic': (-4e-4, 4e-4),
    'dQdt_adiabatic': (-2e-7, 2e-7),
    'SHFLX': (-10, 150),
    'LHFLX': (-10, 150),
    'SOLIN': (0, 1500),
    'SPDT': (-5e-4, 5e-4),
    'SPDQ': (-5e-7, 5e-7),
    'QRL': (-2e-4, 2e-4),
    'QRS': (0, 1e-4),
    'PRECT': (0, 10),
    'FLUT': (0, 300),
}
In [660]:
feature_c = '#08519c'
target_c = '#238b45'
pred_c = '#cb181d'
def vis_features_targets_from_pred2(features, targets,
                                    predictions, sample_idx,
                                    feature_names, target_names,
                                    unscale_targets=False,
                                    save=None):
    """NOTE: FEATURES HARD-CODED!!!
    Features are [TAP, QAP, dTdt_adiabatic, dQdt_adiabatic, SHFLX, LHFLX, SOLIN]
    Targets are [SPDT, SPDQ, QRL, QRS, PRECT, FLUT]
    """
    nz = 21
    z = np.arange(20, -1, -1)
    fig, axes = plt.subplots(2, 5, figsize=(12.5, 10), facecolor='black')
    in_axes = np.ravel(axes[0, :])
    out_axes = np.ravel(axes[1, :])

    for i in range(len(feature_names[:-3])):
        in_axes[i].plot(features[sample_idx, i*nz:(i+1)*nz], z, c=feature_c, lw=3)
        in_axes[i].set_title(feature_names[i], color='white')
        in_axes[i].set_yticks([0, 4, 9, 14, 19])
        in_axes[i].set_yticklabels([1, 5, 10, 15, 20])
        in_axes[i].set_xlabel(unitdict[feature_names[i]], color='white')
        in_axes[i].set_xlim(rangedict[feature_names[i]])
        in_axes[i].axvline(0, c='gray', zorder=0.1, linewidth=0.7)
        in_axes[i].ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    twin_in = in_axes[-1].twinx()
    in_axes[-1].bar([1, 2], features[sample_idx, -3:-1], color=feature_c)
    in_axes[-1].set_ylim(rangedict['SHFLX'])
    twin_in.bar([3], features[sample_idx, -1], color=feature_c)
    twin_in.set_ylim(rangedict['SOLIN'])
    in_axes[-1].set_xticks([1, 2, 3])
    in_axes[-1].set_xticklabels(['SHFLX\n[W/m2]', 'LHFLX\n[W/m2]', 'SOLIN\n[W/m2]'])
    in_axes[0].set_ylabel('Model level', color='white')

    for i in range(len(target_names[:-2])):
        if unscale_targets:
            u = conversion_dict[target_names[i]]
        else:
            u = 1.
        out_axes[i].plot(targets[sample_idx, i * nz:(i + 1) * nz] / u, z,
                         label='SPCAM', c=target_c, lw=3)
        out_axes[i].plot(predictions[sample_idx, i * nz:(i + 1) * nz] / u, z,
                         label='CBRAIN', c=pred_c, lw=3)
        out_axes[i].set_title(target_names[i], color='white')
        out_axes[i].set_yticks([0, 4, 9, 14, 19])
        out_axes[i].set_yticklabels([1, 5, 10, 15, 20])
        out_axes[i].set_xlabel(unitdict[target_names[i]], color='white')
        out_axes[i].set_xlim(rangedict[target_names[i]])
        out_axes[i].axvline(0, c='gray', zorder=0.1, linewidth=0.7)
        out_axes[i].ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    twin_out = out_axes[-1].twinx()
    out_axes[-1].bar(1 - 0.2, targets[sample_idx, -2] / conversion_dict[target_names[-2]]* 1e3 * 3600, 0.4,
                     color=target_c)
    out_axes[-1].bar(1 + 0.2, predictions[sample_idx, -2]/ conversion_dict[target_names[-2]]* 1e3 * 3600,
                     0.4, color=pred_c)
    out_axes[-1].set_ylim(rangedict['PRECT'])
    twin_out.bar(2 - 0.2, targets[sample_idx, -1]/ conversion_dict[target_names[-1]], 0.4,
                     color=target_c)
    twin_out.bar(2 + 0.2, predictions[sample_idx, -1]/ conversion_dict[target_names[-1]],
                     0.4, color=pred_c)
    twin_out.set_ylim(rangedict['FLUT'])
    out_axes[-1].set_xticks([1, 2])
    out_axes[-1].set_xticklabels(['PREC\n[mm/h]', 'OLR\n[W/m2]'])
    out_axes[0].legend(loc=2, frameon=True, framealpha=1)
    out_axes[0].set_ylabel('Model level', color='white')
    #plt.suptitle('Sample %i' % sample_idx, fontsize=15)
    fig.subplots_adjust(bottom=0.15, wspace=0.25, hspace=0.30, left=0.1, right=0.95, top=0.95)
    for ax in list(in_axes) + list(out_axes):
        ax.set_facecolor('whitesmoke')
    if save == None:
        plt.show()
    else:
        fig.savefig(save, dpi=300)
In [659]:
# Good indices for showing network task
good_indices = [
    4020,   # Strong convection
    7300,   # Longwave cooling at cloud top
    12206,
    1585,
    3943,
    19457,
    5675
]
In [661]:
for j, i in enumerate(good_indices):
    print('Sample', i)
    vis_features_targets_from_pred2(
        sample_features * norm.variables['feature_stds'] + norm.variables['feature_means'], 
        sample_targets, 
        sample_preds,
        sample_idx=i, 
        feature_names=feature_names, 
        target_names=target_names,
        unscale_targets=True,
        save='/home/s/S.Rasp/tmp/agu_presentation/agu_samples_' + str(j).zfill(2) +  '.png'
    )
Sample 4020
Sample 7300
Sample 12206
Sample 1585
Sample 3943
Sample 19457
Sample 5675

I think these plots are quite nice because they really show what the network should do and what it does do. But of course this is no statistical analysis. Maybe this could be linked somehow to slide 6? If you would like these plots I would make them nicer and label them properly.

Lat-lev slices of SPDT/SPDQ and QRL/QRS

Lots of ugly code coming up. Scroll to the bottom for my summary. I attached the animations in the email!

In [ ]:
# Define conversion dict
L_V = 2.5e6   # Latent heat of vaporization is actually 2.26e6
C_P = 1e3 # Specific heat capacity of air at constant pressure
conversion_dict = {
    'SPDT': C_P,
    'SPDQ': L_V,
    'QRL': C_P,
    'QRS': C_P,
    'PRECT': 1e3*24*3600 * 1e-3,
    'FLUT': 1. * 1e-5,
}
In [662]:
sns.set_style("white")
from matplotlib.animation import FuncAnimation 
plt.rcParams['axes.facecolor']='black'
plt.rcParams['savefig.facecolor']='black'
plt.rcParams['text.color'] = 'white'
def plot_yz(targets, preds, itime, ilon, var='SP', anim=False, model=None, n_steps=20, 
            interval=150):
    """What an awefully complicated function... sigh"""
    # Reshape
    targets = targets.reshape((-1, n_lat, n_lon, targets.shape[-1]))
    preds = preds.reshape((-1, n_lat, n_lon, preds.shape[-1]))
    
    # Get indices
    tmp_idxs = np.arange(4*21)
    if var == 'SP':
        idxs = [tmp_idxs[:21], tmp_idxs[21:42]]
        names = ['SPDT', 'SPDQ']
        ranges = [[-5e-4, 5e-4], [-5e-7, 5e-7]]
        unit = ['K/s', 'kg/kg/s']
    elif var == 'QR':
        idxs = [tmp_idxs[42:63], tmp_idxs[63:84]]
        names = ['QRL', 'QRS']
        ranges = [[-2e-4, 2e-4], [-1.2e-4, 1.2e-4]]
        unit = ['K/s', 'K/s']
    
    # Unscale variables
    u = [conversion_dict[n] for n in names]
    
    # Plot
    fig, axes = plt.subplots(3, 2, figsize=(12.5,8.5), sharex=True, sharey=True,
                            facecolor='black')
    range_spdt = np.max(np.abs(targets[itime, :, ilon, idxs[0]] / u[0]))
    range_spdq = np.max(np.abs(targets[itime, :, ilon, idxs[1]] / u[1]))
    # Target SPDT
    I00 = axes[0, 0].imshow(targets[itime, :, ilon, idxs[0]] / u[0],
                       vmin=ranges[0][0], vmax=ranges[0][1], cmap='bwr')
    axes[0, 0].set_title('SPCAM ' + names[0])
    #fig.colorbar(I, ax=axes[0, 0], shrink=0.4)
    # Pred SPDT
    I10 = axes[1, 0].imshow(preds[itime, :, ilon, idxs[0]] / u[0],
                       vmin=ranges[0][0], vmax=ranges[0][1], cmap='bwr')
    axes[1, 0].set_title('CLOUDBRAIN ' + names[0])
    #fig.colorbar(I, ax=axes[1, 0], shrink=0.4)
    # Error 1
    e = preds[itime, :, ilon, idxs[0]] / u[0] - targets[itime, :, ilon, idxs[0]] / u[0]
    I20 = axes[2, 0].imshow(e,
                       vmin=ranges[0][0], vmax=ranges[0][1], cmap='bwr')
    axes[2, 0].set_title('Difference ' + names[0])
    cbar_ax_left = fig.add_axes([0.14, 0.1, 0.25, 0.02])
    cb_left = fig.colorbar(I20, cax=cbar_ax_left, orientation='horizontal', 
                            ticks=[ranges[0][0], 0, ranges[0][1]], format='%.0e')
    cb_left.set_label(unit[0], color='white')
    # Target SPDQ
    I01 = axes[0, 1].imshow(targets[itime, :, ilon, idxs[1]] / u[1],
                       vmin=ranges[1][0], vmax=ranges[1][1], cmap='bwr')
    axes[0, 1].set_title('SPCAM ' + names[1], color='white')
    #fig.colorbar(I, ax=axes[0, 1], shrink=0.4)
    # Pred SPDQ
    I11 = axes[1, 1].imshow(preds[itime, :, ilon, idxs[1]] / u[1],
                       vmin=ranges[1][0], vmax=ranges[1][1], cmap='bwr')
    axes[1, 1].set_title('CLOUDBRAIN ' + names[1])
    #fig.colorbar(I, ax=axes[1, 1], shrink=0.4)
    # Error 2
    e = preds[itime, :, ilon, idxs[1]] / u[1] - targets[itime, :, ilon, idxs[1]] / u[1]
    I21 = axes[2, 1].imshow(e,
                       vmin=ranges[1][0], vmax=ranges[1][1], cmap='bwr')
    axes[2, 1].set_title('Difference ' + names[1])
    fig.suptitle('Day: %i - Hour: %.1f' % (0, 0), fontsize=17)
    fig.subplots_adjust(bottom=0.15, wspace=0.1, hspace=0.25, left=0.05, right=0.95, top=0.9)
    cbar_ax_right = fig.add_axes([0.62, 0.1, 0.25, 0.02])
    cb_right = fig.colorbar(I21, cax=cbar_ax_right, orientation='horizontal', 
                            ticks=[ranges[1][0], 0, ranges[1][1]], format='%.0e')
    cb_right.set_label(unit[1], color='white')
    cb_right.ax.tick_params(axis='x', colors='white')
    cb_left.ax.tick_params(axis='x', colors='white')
    
    for ax in list(np.ravel(axes)):
        ax.set_xticks([])
        ax.set_yticks([])
    sns.despine(left=True, bottom=True)
    #plt.tight_layout()
    if anim:
        gen_obj = DataGenerator(
            data_dir,
            'valid_year1' + '_features.nc',
            'valid_year1' + '_targets.nc',
            shuffle=False,
            batch_size=n_geo,
            verbose=False,
        )
        gen = gen_obj.return_generator()
        def update(i):
            tmp_features, tmp_targets = next(gen)
            # Get predictions
            tmp_preds = model.predict_on_batch(tmp_features)
            targets = tmp_targets.reshape((-1, n_lat, n_lon, tmp_targets.shape[-1]))
            preds = tmp_preds.reshape((-1, n_lat, n_lon, tmp_preds.shape[-1]))
            I00.set_data(targets[itime, :, ilon, idxs[0]] / u[0])
            I10.set_data(preds[itime, :, ilon, idxs[0]] / u[0])
            I20.set_data(preds[itime, :, ilon, idxs[0]] / u[0] - 
                         targets[itime, :, ilon, idxs[0]] / u[0])
            I01.set_data(targets[itime, :, ilon, idxs[1]] / u[1])
            I11.set_data(preds[itime, :, ilon, idxs[1]] / u[1])
            I21.set_data(preds[itime, :, ilon, idxs[1]] / u[1] - 
                         targets[itime, :, ilon, idxs[1]] / u[1])
            hour = i % 48 / 2
            day = (i - hour) / 48
            fig.suptitle('Day: %i - Hour: %.1f' % (day, hour), fontsize=17)
            return I00, I10, I20, I01, I11, I21
        return FuncAnimation(fig, update, frames=np.arange(n_steps), interval=interval, blit=True)
    else:
        plt.show()
In [663]:
model = keras.models.load_model('./models/002_1year_shuffle.h5')
anim = plot_yz(sample_targets, sample_preds, 0, 10, 'SP', anim=True, model=model, n_steps=48 * 5)
In [282]:
HTML(anim.to_html5_video())
Out[282]:
In [664]:
anim.save('/home/s/S.Rasp/tmp/SP_lat_lev_v1.mp4', dpi=80)
In [665]:
model = keras.models.load_model('./models/002_1year_shuffle.h5')
anim = plot_yz(sample_targets, sample_preds, 0, 10, 'QR', anim=True, model=model, n_steps=48 * 5)
In [285]:
HTML(anim.to_html5_video())
Out[285]:
In [666]:
anim.save('/home/s/S.Rasp/tmp/QR_lat_lev_v1.mp4', dpi=80)

Lat-lon plots of PRECT and FLUT

In [667]:
sns.set_style("white")
from matplotlib.animation import FuncAnimation 
plt.rcParams['axes.facecolor']='black'
plt.rcParams['savefig.facecolor']='black'
plt.rcParams['text.color'] = 'white'
def plot_xy(targets, preds, itime, anim=False, model=None, n_steps=20, 
            interval=150):
    """What an awefully complicated function... part 2"""
    # Reshape
    targets = targets.reshape((-1, n_lat, n_lon, targets.shape[-1]))
    preds = preds.reshape((-1, n_lat, n_lon, preds.shape[-1]))
    
    # Get indices
    names = ['PREC', 'OLR']
    ranges = [[0, 7], [0, 300]]
    diff_ranges = [[-2, 2], [-80, 80]]
    unit = ['mm/h', r'W/m$^2$']
    
    # Unscale variables
    u = [conversion_dict[n] for n in ['PRECT', 'FLUT']]
    u[0] = u[0] / 1e3 / 3600
    # Plot
    fig, axes = plt.subplots(3, 2, figsize=(12.5,8.5), sharex=True, sharey=True,
                            facecolor='black')
    # Target SPDT
    I00 = axes[0, 0].imshow(targets[itime, :, :, -2] / u[0],
                       vmin=ranges[0][0], vmax=ranges[0][1], cmap='Blues_r')
    axes[0, 0].set_title('SPCAM ' + names[0])
    cbar_ax_left1 = fig.add_axes([0.435, 0.55, 0.01, 0.15])
    cb_left1 = fig.colorbar(I00, cax=cbar_ax_left1, orientation='vertical', 
                            ticks=[ranges[0][0], 0, ranges[0][1]], format='%.0e')
    cb_left1.set_label(unit[0], color='white')
    cb_left1.ax.set_yticklabels([ranges[0][0], 0, ranges[0][1]],rotation=90, color='w')
    # Pred SPDT
    I10 = axes[1, 0].imshow(preds[itime, :, :, -2] / u[0],
                       vmin=ranges[0][0], vmax=ranges[0][1], cmap='Blues_r')
    axes[1, 0].set_title('CLOUDBRAIN ' + names[0])
    #fig.colorbar(I, ax=axes[1, 0], shrink=0.4)
    # Error 1
    e = preds[itime, :, :, -2] / u[0] - targets[itime, :, :, -2] / u[0]
    I20 = axes[2, 0].imshow(e,
                       vmin=diff_ranges[0][0], vmax=diff_ranges[0][1], cmap='bwr')
    axes[2, 0].set_title('Difference ' + names[0])
    cbar_ax_left = fig.add_axes([0.435, 0.1, 0.01, 0.15])
    cb_left = fig.colorbar(I20, cax=cbar_ax_left, orientation='vertical', 
                            ticks=[diff_ranges[0][0], 0, diff_ranges[0][1]], format='%.0f')
    cb_left.set_label(unit[0], color='white')
    cb_left.ax.set_yticklabels([diff_ranges[0][0], 0, diff_ranges[0][1]],rotation=90, color='w')
    # Target SPDQ
    I01 = axes[0, 1].imshow(targets[itime, :, :, -1] / u[1],
                       vmin=ranges[1][0], vmax=ranges[1][1], cmap='viridis')
    axes[0, 1].set_title('SPCAM ' + names[1], color='white')
    #fig.colorbar(I, ax=axes[0, 1], shrink=0.4)
    # Pred SPDQ
    I11 = axes[1, 1].imshow(preds[itime, :, :, -1] / u[1],
                       vmin=ranges[1][0], vmax=ranges[1][1], cmap='viridis')
    axes[1, 1].set_title('CLOUDBRAIN ' + names[1])
    cbar_ax_right1 = fig.add_axes([0.86, 0.55, 0.01, 0.15])
    cb_right1 = fig.colorbar(I01, cax=cbar_ax_right1, orientation='vertical', 
                            ticks=[ranges[1][0], 0, ranges[1][1]], format='%.0e')
    cb_right1.set_label(unit[1], color='white')
    cb_right1.ax.set_yticklabels([ranges[1][0], 0, ranges[1][1]],rotation=90, color='w')
    # Error 2
    e = preds[itime, :, :, -1] / u[1] - targets[itime, :, :, -1] / u[1]
    I21 = axes[2, 1].imshow(e,
                       vmin=diff_ranges[1][0], vmax=diff_ranges[1][1], cmap='bwr')
    axes[2, 1].set_title('Difference ' + names[1])
    fig.suptitle('Day: %i - Hour: %.1f' % (0, 0), fontsize=17)
    fig.subplots_adjust(bottom=0.05, wspace=0.12, hspace=0.25, left=0.05, right=0.85, top=0.9)
    cbar_ax_right = fig.add_axes([0.86, 0.1, 0.01, 0.15])
    cb_right = fig.colorbar(I21, cax=cbar_ax_right, orientation='vertical', 
                            ticks=[diff_ranges[1][0], 0, diff_ranges[1][1]], format='%.0e')
    cb_right.set_label(unit[1], color='white')
    cb_right.ax.set_yticklabels([diff_ranges[1][0], 0, diff_ranges[1][1]],rotation=90, color='w')
    cb_right.ax.tick_params(axis='x', colors='white')
    cb_left.ax.tick_params(axis='x', colors='white')
    
    for ax in list(np.ravel(axes)):
        ax.set_xticks([])
        ax.set_yticks([])
    sns.despine(left=True, bottom=True)
    #plt.tight_layout()
    if anim:
        gen_obj = DataGenerator(
            data_dir,
            'valid_year1' + '_features.nc',
            'valid_year1' + '_targets.nc',
            shuffle=False,
            batch_size=n_geo,
            verbose=False,
        )
        gen = gen_obj.return_generator()
        def update(i):
            tmp_features, tmp_targets = next(gen)
            # Get predictions
            tmp_preds = model.predict_on_batch(tmp_features)
            targets = tmp_targets.reshape((-1, n_lat, n_lon, tmp_targets.shape[-1]))
            preds = tmp_preds.reshape((-1, n_lat, n_lon, tmp_preds.shape[-1]))
            I00.set_data(targets[itime, :, :, -2] / u[0])
            I10.set_data(preds[itime, :, :, -2] / u[0])
            I20.set_data(preds[itime, :, :, -2] / u[0] - 
                         targets[itime, :, :, -2] / u[0])
            I01.set_data(targets[itime, :, :,-1] / u[1])
            I11.set_data(preds[itime, :, :, -1] / u[1])
            I21.set_data(preds[itime, :, :, -1] / u[1] - 
                         targets[itime, :, :, -1] / u[1])
            hour = i % 48 / 2
            day = (i - hour) / 48
            fig.suptitle('Day: %i - Hour: %.1f' % (day, hour), fontsize=17)
            return I00, I10, I20, I01, I11, I21
        return FuncAnimation(fig, update, frames=np.arange(n_steps), interval=interval, blit=True)
    else:
        plt.show()
In [668]:
model = keras.models.load_model('./models/002_1year_shuffle.h5')
anim = plot_xy(sample_targets, sample_preds, 0, anim=True, model=model, n_steps=48 * 10,
              interval=75)
In [377]:
HTML(anim.to_html5_video())
Out[377]:
In [669]:
anim.save('/home/s/S.Rasp/tmp/PREC_OLR_lat_lon_v1.mp4', dpi=80)

Summary of visualizations

Here are some things I noticed looking at the animations.

  • Lack of variability. This is visible in the SP slices, where the values are generally too low, and in the precipitation plots, where the variability in the sub-tropics is missing. This makes sense since the network tries to fit a smoothed version. Maybe mention further work: trying to fit the variance, or GANs (Generative Adversarial Networks, something I would love to try out this spring)
  • Heating rates are very good.